# -*- coding: utf-8 -*-
"""
Created on Thu Mar 30 14:52:51 2023

@author: lv
"""
import sys
import os
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch
from utils.SelfAttentionEncoder import SelfAttentionEncoder

from transformers import BertTokenizer
from utils.DigitEmbedding import DigitEmbedding

# Load the pretrained Chinese BERT tokenizer (downloads on first use).
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')

max_seq_length = 500  # fixed length every encoded sequence is padded to
batch_size = 1        # NOTE(review): set to 1 but two texts are fed below — confirm intent
num_digits = 6        # digits per token id produced by DigitEmbedding

# Initialize the SelfAttentionEncoder model.
# NOTE(review): constructor arguments assumed from the names here; the
# class is project-local and not visible in this file.
model = SelfAttentionEncoder(max_seq_length=max_seq_length,
                input_max_num_sequences=batch_size,
                num_digits=num_digits)

# Two example sentences to run through the encoder.
text1 = "你好，自然语言处理"
text2 = "这是一条测试SelfAttentionEncoder的语句"
input_texts = [text1, text2]

# Converts tokenizer ids into per-digit numeric encodings.
digit_embedding = DigitEmbedding(num_digits, max_seq_length)

def encode_text(text: str) -> torch.Tensor:
    """Encode *text* into a fixed-length digit-embedded tensor.

    Tokenizes with the module-level BERT tokenizer (special tokens
    included), truncates/pads the id sequence to ``max_seq_length``,
    passes it through the module-level ``digit_embedding``, and returns
    the result reshaped to ``(-1, num_digits)``.

    Args:
        text: The input sentence to encode.

    Returns:
        A tensor of shape ``(max_seq_length, num_digits)`` — assuming
        ``digit_embedding.encode`` yields ``num_digits`` values per id.
    """
    encoded_text = tokenizer.encode(text, add_special_tokens=True)
    # Bug fix: the original computed `[0] * (max_seq_length - len(...))`,
    # which for over-long inputs is an empty list — the sequence was left
    # longer than max_seq_length. Truncate first, then pad.
    encoded_text = encoded_text[:max_seq_length]
    # Right-pad with 0 (the [PAD] id for bert-base-chinese) to fixed length.
    encoded_text += [0] * (max_seq_length - len(encoded_text))
    encoded_text = digit_embedding.encode(encoded_text)
    return torch.tensor(encoded_text).view(-1, num_digits)

# Encode every input sentence into its digit-embedded tensor form.
input_ids = [encode_text(sentence) for sentence in input_texts]

# Run the encoder over the encoded inputs.
# NOTE(review): a plain Python list of tensors is handed to the model —
# confirm SelfAttentionEncoder accepts a list rather than a stacked tensor.
output = model(input_ids)

# Decode each row of the model output back to text with the tokenizer.
decoded_texts = [
    tokenizer.decode(output[row], skip_special_tokens=True)
    for row in range(output.shape[0])
]

# Print the decoded results.
print(decoded_texts)